# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import random
import unittest

import numpy as np

import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.nn as nn
from paddle import Model, set_device
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.layers.utils import map_structure
from paddle.nn import (
    RNN,
    BeamSearchDecoder,
    Embedding,
    Layer,
    Linear,
    LSTMCell,
    SimpleRNNCell,
    dynamic_decode,
)
from paddle.static import InputSpec as Input

paddle.enable_static()


class PolicyGradient:
    """policy gradient"""

    def __init__(self, lr=None):
        self.lr = lr

    def learn(self, act_prob, action, reward, length=None):
        """
        update policy model self.model with policy gradient algorithm
        """
        self.reward = paddle.static.py_func(
            func=reward_func, x=[action, length], out=reward
        )
        neg_log_prob = paddle.nn.functional.cross_entropy(
            act_prob, action, reduction='none', use_softmax=False
        )
        cost = neg_log_prob * reward
        cost = (
            (paddle.sum(cost) / paddle.sum(length))
            if length is not None
            else paddle.mean(cost)
        )
        optimizer = fluid.optimizer.Adam(self.lr)
        optimizer.minimize(cost)
        return cost


def reward_func(samples, sample_length):
    """toy reward"""

    def discount_reward(reward, sequence_length, discount=1.0):
        return discount_reward_1d(reward, sequence_length, discount)

    def discount_reward_1d(reward, sequence_length, discount=1.0, dtype=None):
        if sequence_length is None:
            raise ValueError(
                'sequence_length must not be `None` for 1D reward.'
            )
        reward = np.array(reward)
        sequence_length = np.array(sequence_length)
        batch_size = reward.shape[0]
        max_seq_length = np.max(sequence_length)
        dtype = dtype or reward.dtype
        if discount == 1.0:
            dmat = np.ones([batch_size, max_seq_length], dtype=dtype)
        else:
            steps = np.tile(np.arange(max_seq_length), [batch_size, 1])
            mask = np.asarray(
                steps < (sequence_length - 1)[:, None], dtype=dtype
            )
            # Make each row = [discount, ..., discount, 1, ..., 1]
            dmat = mask * discount + (1 - mask)
            dmat = np.cumprod(dmat[:, ::-1], axis=1)[:, ::-1]
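            # e.g. discount=0.5, length=4: mask=[1,1,1,0], so
            # dmat=[0.5,0.5,0.5,1] -> reversed cumprod -> [0.125,0.25,0.5,1]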
        disc_reward = dmat * reward[:, None]
        disc_reward = mask_sequences(disc_reward, sequence_length, dtype=dtype)
        return disc_reward

    def mask_sequences(sequence, sequence_length, dtype=None, time_major=False):
        sequence = np.array(sequence)
        sequence_length = np.array(sequence_length)
        rank = sequence.ndim
        if rank < 2:
            raise ValueError("`sequence` must be 2D or higher order.")
        batch_size = sequence.shape[0]
        max_time = sequence.shape[1]
        dtype = dtype or sequence.dtype
        if time_major:
            sequence = np.transpose(sequence, axes=[1, 0, 2])
        steps = np.tile(np.arange(max_time), [batch_size, 1])
        mask = np.asarray(steps < sequence_length[:, None], dtype=dtype)
        for _ in range(2, rank):
            mask = np.expand_dims(mask, -1)
        sequence = sequence * mask
        if time_major:
            sequence = np.transpose(sequence, axes=[1, 0, 2])
        return sequence

    samples = np.array(samples)
    sample_length = np.array(sample_length)
    # length reward: maximal when the sample length is exactly 5
    reward = (5 - np.abs(sample_length - 5)).astype("float32")
    # Penalize repetition to avoid a local minimum where all sampled words
    # are the same; drawing multiple samples via beam search may also help.
    for i in range(reward.shape[0]):
        reward[i] += (
            -10
            if sample_length[i] > 1
            and np.all(samples[i][: sample_length[i] - 1] == samples[i][0])
            else 0
        )
    return discount_reward(reward, sample_length, discount=1.0).astype(
        "float32"
    )


class MLE:
    """teacher-forcing MLE training"""

    def __init__(self, lr=None):
        self.lr = lr

    def learn(self, probs, label, weight=None, length=None):
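        """
        Compute the teacher-forcing MLE loss: token-level cross-entropy
        masked by sequence length, averaged over the batch and summed over
        time, then minimized with Adam. `weight` is currently unused.
        """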
        loss = paddle.nn.functional.cross_entropy(
            input=probs,
            label=label,
            soft_label=False,
            reduction='none',
            use_softmax=False,
        )
        max_seq_len = paddle.shape(probs)[1]
        mask = paddle.static.nn.sequence_lod.sequence_mask(
            length, maxlen=max_seq_len, dtype="float32"
        )
        loss = loss * mask
        loss = paddle.mean(loss, axis=[0])
        loss = paddle.sum(loss)
        optimizer = fluid.optimizer.Adam(self.lr)
        optimizer.minimize(loss)
        return loss


class SeqPGAgent:
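    """
    Wrap a seq2seq model and a learning algorithm (policy gradient by
    default) into static-graph programs: a training program and a pruned
    prediction program that share parameters.

    A minimal usage sketch (the feed names match the data layers defined in
    `build_program`; the executor setup is illustrative):

        exe = fluid.Executor(fluid.CPUPlace())
        agent = SeqPGAgent(model_cls=BaseModel, executor=exe, seed=123)
        exe.run(agent.startup_program)
        samples, sample_length = agent.predict(
            {"src": src_ids, "src_sequence_length": src_len}
        )
    """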
    def __init__(
        self,
        model_cls,
        alg_cls=PolicyGradient,
        model_hparams=None,
        alg_hparams=None,
        executor=None,
        main_program=None,
        startup_program=None,
        seed=None,
    ):
        # Resolve the hparams here to avoid mutable default arguments.
        model_hparams = model_hparams if model_hparams is not None else {}
        alg_hparams = alg_hparams if alg_hparams is not None else {}
        self.main_program = (
            fluid.Program() if main_program is None else main_program
        )
        self.startup_program = (
            fluid.Program() if startup_program is None else startup_program
        )
        if seed is not None:
            self.main_program.random_seed = seed
            self.startup_program.random_seed = seed
        self.build_program(model_cls, alg_cls, model_hparams, alg_hparams)
        self.executor = executor

    def build_program(self, model_cls, alg_cls, model_hparams, alg_hparams):
        with fluid.program_guard(self.main_program, self.startup_program):
            source = fluid.data(name="src", shape=[None, None], dtype="int64")
            source_length = fluid.data(
                name="src_sequence_length", shape=[None], dtype="int64"
            )
            # only for teacher-forcing MLE training
            target = fluid.data(name="trg", shape=[None, None], dtype="int64")
            target_length = fluid.data(
                name="trg_sequence_length", shape=[None], dtype="int64"
            )
            label = fluid.data(
                name="label", shape=[None, None, 1], dtype="int64"
            )
            self.model = model_cls(**model_hparams)
            self.alg = alg_cls(**alg_hparams)
            self.probs, self.samples, self.sample_length = self.model(
                source, source_length, target, target_length
            )
            self.reward = fluid.data(
                name="reward",
                shape=[None, None],  # batch_size, seq_len
                dtype=self.probs.dtype,
            )
            # `samples` are discrete token ids, so no gradient flows through
            # them; keep them connected to the loss graph nonetheless.
            self.samples.stop_gradient = False
            self.cost = self.alg.learn(
                self.probs, self.samples, self.reward, self.sample_length
            )

        # Prune the training program to obtain a prediction program that
        # shares its parameters.
        self.pred_program = self.main_program._prune_with_input(
            [source.name, source_length.name],
            [self.probs, self.samples, self.sample_length],
        )

    def predict(self, feed_dict):
        samples, sample_length = self.executor.run(
            self.pred_program,
            feed=feed_dict,
            fetch_list=[self.samples, self.sample_length],
        )
        return samples, sample_length

    def learn(self, feed_dict, fetch_list):
        results = self.executor.run(
            self.main_program, feed=feed_dict, fetch_list=fetch_list
        )
        return results


class ModuleApiTest(unittest.TestCase):
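    """
    Base class for module API tests: it wraps the module under test into a
    `Model`, runs it in both dygraph and static-graph modes, and asserts
    that the outputs are close.

    Subclasses provide `model_init`, `model_forward` and `make_inputs`
    (see `TestBeamSearch` below for a concrete example).
    """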
    @classmethod
    def setUpClass(cls):
        cls._np_rand_state = np.random.get_state()
        cls._py_rand_state = random.getstate()
        cls._random_seed = 123
        np.random.seed(cls._random_seed)
        random.seed(cls._random_seed)

        cls.model_cls = type(
            cls.__name__ + "Model",
            (Layer,),
            {
                "__init__": cls.model_init_wrapper(cls.model_init),
                "forward": cls.model_forward,
            },
        )

    @classmethod
    def tearDownClass(cls):
        np.random.set_state(cls._np_rand_state)
        random.setstate(cls._py_rand_state)

    @staticmethod
    def model_init_wrapper(func):
        def __impl__(self, *args, **kwargs):
            Layer.__init__(self)
            func(self, *args, **kwargs)

        return __impl__

    @staticmethod
    def model_init(model, *args, **kwargs):
        raise NotImplementedError(
            "model_init acts as the generated model class's `__init__`, "
            "thus must be implemented"
        )

    @staticmethod
    def model_forward(model, *args, **kwargs):
        return model.module(*args, **kwargs)

    def make_inputs(self):
        # TODO(guosheng): add default from `self.inputs`
        raise NotImplementedError(
            "make_inputs builds the input specs for the model, "
            "thus must be implemented"
        )

    def setUp(self):
        """
        For the model which wraps the module to be tested:
            Set input data by `self.inputs` list
            Set init argument values by `self.attrs` list/dict
            Set model parameter values by `self.param_states` dict
            Set expected output data by `self.outputs` list
        We can create a model instance and run once with these.
        """
        self.inputs = []
        self.attrs = {}
        self.param_states = {}
        self.outputs = []

    def _calc_output(self, place, mode="test", dygraph=True):
        if dygraph:
            fluid.enable_dygraph(place)
        else:
            fluid.disable_dygraph()
        paddle.seed(self._random_seed)
        paddle.framework.random._manual_program_seed(self._random_seed)
        scope = fluid.core.Scope()
        with fluid.scope_guard(scope):
            layer = (
                self.model_cls(**self.attrs)
                if isinstance(self.attrs, dict)
                else self.model_cls(*self.attrs)
            )
            model = Model(layer, inputs=self.make_inputs())
            model.prepare()
            if self.param_states:
                model.load(self.param_states, optim_state=None)
            return model.predict_batch(self.inputs)

    def check_output_with_place(self, place, mode="test"):
        dygraph_output = self._calc_output(place, mode, dygraph=True)
        stgraph_output = self._calc_output(place, mode, dygraph=False)
        expect_output = getattr(self, "outputs", None)
        for actual_t, expect_t in zip(dygraph_output, stgraph_output):
            np.testing.assert_allclose(actual_t, expect_t, rtol=1e-05, atol=0)
        if expect_output:
            for actual_t, expect_t in zip(dygraph_output, expect_output):
                np.testing.assert_allclose(
                    actual_t, expect_t, rtol=1e-05, atol=0
                )

    def check_output(self):
        devices = ["CPU", "GPU"] if fluid.is_compiled_with_cuda() else ["CPU"]
        for device in devices:
            place = set_device(device)
            self.check_output_with_place(place)


class TestBeamSearch(ModuleApiTest):
    def setUp(self):
        paddle.set_default_dtype("float64")
        shape = (8, 32)
        self.inputs = [
            np.random.random(shape).astype("float64"),
            np.random.random(shape).astype("float64"),
        ]
        self.outputs = None
        self.attrs = {
            "vocab_size": 100,
            "embed_dim": 32,
            "hidden_size": 32,
        }
        self.param_states = {}

    @staticmethod
    def model_init(
        self,
        vocab_size,
        embed_dim,
        hidden_size,
        bos_id=0,
        eos_id=1,
        beam_size=4,
        max_step_num=20,
    ):
        embedder = Embedding(vocab_size, embed_dim)
        output_layer = nn.Linear(hidden_size, vocab_size)
        cell = nn.LSTMCell(embed_dim, hidden_size)
        self.max_step_num = max_step_num
        self.beam_search_decoder = BeamSearchDecoder(
            cell,
            start_token=bos_id,
            end_token=eos_id,
            beam_size=beam_size,
            embedding_fn=embedder,
            output_fn=output_layer,
        )

    @staticmethod
    def model_forward(model, init_hidden, init_cell):
        return dynamic_decode(
            model.beam_search_decoder,
            [init_hidden, init_cell],
            max_step_num=model.max_step_num,
            impute_finished=True,
            is_test=True,
        )[0]

    def make_inputs(self):
        inputs = [
            Input([None, self.inputs[0].shape[-1]], "float64", "init_hidden"),
            Input([None, self.inputs[1].shape[-1]], "float64", "init_cell"),
        ]
        return inputs

    def test_check_output(self):
        # `setUp` is invoked automatically by unittest before each test.
        self.check_output()


class EncoderCell(SimpleRNNCell):
    def __init__(
        self,
        num_layers,
        input_size,
        hidden_size,
        dropout_prob=0.0,
        init_scale=0.1,
    ):
        super().__init__(input_size, hidden_size)
        self.dropout_prob = dropout_prob
        # use add_sublayer to register the stacked cells as sublayers
        self.lstm_cells = []
        for i in range(num_layers):
            self.lstm_cells.append(
                self.add_sublayer(
                    "lstm_%d" % i,
                    LSTMCell(
                        input_size=input_size if i == 0 else hidden_size,
                        hidden_size=hidden_size,
                    ),
                )
            )

    def forward(self, step_input, states):
        new_states = []
        for i, lstm_cell in enumerate(self.lstm_cells):
            out, new_state = lstm_cell(step_input, states[i])
            step_input = (
                layers.dropout(
                    out,
                    self.dropout_prob,
                    dropout_implementation='upscale_in_train',
                )
                if self.dropout_prob > 0
                else out
            )
            new_states.append(new_state)
        return step_input, new_states

    @property
    def state_shape(self):
        return [cell.state_shape for cell in self.lstm_cells]


class Encoder(Layer):
    def __init__(
        self,
        vocab_size,
        embed_dim,
        hidden_size,
        num_layers,
        dropout_prob=0.0,
        init_scale=0.1,
    ):
        super().__init__()
        self.embedder = Embedding(vocab_size, embed_dim)
        self.stack_lstm = RNN(
            EncoderCell(
                num_layers, embed_dim, hidden_size, dropout_prob, init_scale
            ),
            is_reverse=False,
            time_major=False,
        )

    def forward(self, sequence, sequence_length):
        inputs = self.embedder(sequence)
        encoder_output, encoder_state = self.stack_lstm(
            inputs, sequence_length=sequence_length
        )
        return encoder_output, encoder_state


DecoderCell = EncoderCell


class Decoder(Layer):
    def __init__(
        self,
        vocab_size,
        embed_dim,
        hidden_size,
        num_layers,
        dropout_prob=0.0,
        init_scale=0.1,
    ):
        super().__init__()
        self.embedder = Embedding(vocab_size, embed_dim)
        self.stack_lstm = RNN(
            DecoderCell(
                num_layers, embed_dim, hidden_size, dropout_prob, init_scale
            ),
            is_reverse=False,
            time_major=False,
        )
        self.output_layer = Linear(hidden_size, vocab_size, bias_attr=False)

    def forward(self, target, decoder_initial_states):
        inputs = self.embedder(target)
        decoder_output, _ = self.stack_lstm(
            inputs, initial_states=decoder_initial_states
        )
        predict = self.output_layer(decoder_output)
        return predict


class TrainingHelper:
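    """
    Teacher-forcing helper: feeds the ground-truth input of each time step
    to the decoder and marks sequences finished once their length is
    reached.
    """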
    def __init__(self, inputs, sequence_length, time_major=False):
        self.inputs = inputs
        self.sequence_length = sequence_length
        self.time_major = time_major
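        # Pad one extra step at the end of the time dimension so that
        # `next_inputs` can always slice at `time + 1`.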
        self.inputs_ = map_structure(
            lambda x: paddle.nn.functional.pad(
                x,
                pad=([0, 1] + [0, 0] * (len(x.shape) - 1))
                if time_major
                else ([0, 0, 0, 1] + [0, 0] * (len(x.shape) - 2)),
            ),
            self.inputs,
        )

    def initialize(self):
        init_finished = paddle.equal(
            self.sequence_length,
            paddle.full(
                shape=[1], dtype=self.sequence_length.dtype, fill_value=0
            ),
        )
        init_inputs = map_structure(
            lambda x: x[0] if self.time_major else x[:, 0], self.inputs
        )
        return init_inputs, init_finished

    def sample(self, time, outputs, states):
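        # Greedy sampling: take the argmax token id of the step outputs.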
        sample_ids = paddle.argmax(outputs, axis=-1)
        return sample_ids

    def next_inputs(self, time, outputs, states, sample_ids):
        time = (
            paddle.cast(time, "int32")
            if convert_dtype(time.dtype) not in ["int32"]
            else time
        )
        if self.sequence_length.dtype != time.dtype:
            self.sequence_length = paddle.cast(self.sequence_length, time.dtype)
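        # A sequence is finished once the next time index reaches its length.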
        next_time = time + 1
        finished = paddle.less_equal(self.sequence_length, next_time)

        def _slice(x):
            axes = [0 if self.time_major else 1]
            return paddle.squeeze(
                paddle.slice(
                    x, axes=axes, starts=[next_time], ends=[next_time + 1]
                ),
                axis=axes,
            )

        next_inputs = map_structure(_slice, self.inputs_)
        return finished, next_inputs, states


class BasicDecoder(paddle.nn.decode.Decoder):
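    """
    A minimal decoder for `dynamic_decode`: it runs one cell step per time
    step and delegates sampling and next-input selection to the helper.
    """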
    def __init__(self, cell, helper, output_fn=None):
        super().__init__()
        self.cell = cell
        self.helper = helper
        self.output_fn = output_fn

    def initialize(self, initial_cell_states):
        (initial_inputs, initial_finished) = self.helper.initialize()
        return initial_inputs, initial_cell_states, initial_finished

    OutputWrapper = collections.namedtuple(
        "OutputWrapper", ("cell_outputs", "sample_ids")
    )

    def step(self, time, inputs, states, **kwargs):
        cell_outputs, cell_states = self.cell(inputs, states, **kwargs)
        if self.output_fn is not None:
            cell_outputs = self.output_fn(cell_outputs)
        sample_ids = self.helper.sample(
            time=time, outputs=cell_outputs, states=cell_states
        )
        sample_ids.stop_gradient = True
        (finished, next_inputs, next_states) = self.helper.next_inputs(
            time=time,
            outputs=cell_outputs,
            states=cell_states,
            sample_ids=sample_ids,
        )
        outputs = self.OutputWrapper(cell_outputs, sample_ids)
        return (outputs, next_states, next_inputs, finished)


class BaseModel(Layer):
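    """
    A toy seq2seq model: a stacked-LSTM encoder and a stacked-LSTM decoder
    driven step by step through `dynamic_decode` with a teacher-forcing
    `TrainingHelper`.
    """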
    def __init__(
        self,
        vocab_size=10,
        embed_dim=32,
        hidden_size=32,
        num_layers=1,
        dropout_prob=0.0,
        init_scale=0.1,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.word_embedding = Embedding(vocab_size, embed_dim)
        self.encoder = Encoder(
            vocab_size,
            embed_dim,
            hidden_size,
            num_layers,
            dropout_prob,
            init_scale,
        )
        self.decoder = Decoder(
            vocab_size,
            embed_dim,
            hidden_size,
            num_layers,
            dropout_prob,
            init_scale,
        )

    def forward(self, src, src_length, trg, trg_length):
        # `Encoder.forward` returns (output, final_state); the output is
        # used as `batch_ref` when building the decoder's initial states.
        encoder_output, _ = self.encoder(src, src_length)
        trg_emb = self.decoder.embedder(trg)
        helper = TrainingHelper(inputs=trg_emb, sequence_length=trg_length)
        decoder = BasicDecoder(self.decoder.stack_lstm.cell, helper)
        (
            decoder_output,
            decoder_final_state,
            dec_seq_lengths,
        ) = dynamic_decode(
            decoder,
            inits=self.decoder.stack_lstm.cell.get_initial_states(
                encoder_output
            ),
            impute_finished=True,
            is_test=False,
            return_length=True,
        )
        logits, samples, sample_length = (
            decoder_output.cell_outputs,
            decoder_output.sample_ids,
            dec_seq_lengths,
        )
        return logits


class TestDynamicDecode(ModuleApiTest):
    def setUp(self):
        paddle.set_default_dtype("float64")
        shape = (1, 10)
        bs_shape = 1
        self.inputs = [
            np.random.randint(0, 10, size=shape).astype("int64"),
            np.random.randint(0, 10, size=bs_shape).astype("int64"),
            np.random.randint(0, 10, size=shape).astype("int64"),
            np.random.randint(0, 10, size=bs_shape).astype("int64"),
        ]
        self.outputs = None
        self.attrs = {
            "vocab_size": 10,
            "embed_dim": 32,
            "hidden_size": 32,
        }
        self.param_states = {}

    @staticmethod
    def model_init(
        self,
        vocab_size,
        embed_dim,
        hidden_size,
        bos_id=0,
        eos_id=1,
    ):
        self.model = BaseModel(
            vocab_size=vocab_size, embed_dim=embed_dim, hidden_size=hidden_size
        )

    @staticmethod
    def model_forward(model, src, src_length, trg, trg_length):
        return model.model(src, src_length, trg, trg_length)

    def make_inputs(self):
        inputs = [
            Input([None, None], "int64", "src"),
            Input([None], "int64", "src_length"),
            Input([None, None], "int64", "trg"),
            Input([None], "int64", "trg_length"),
        ]
        return inputs

    def test_check_output(self):
        # `setUp` is invoked automatically by unittest before each test.
        self.check_output()


if __name__ == '__main__':
    unittest.main()
